from quant.const import OrderStatus, Offset
from quant.utils import logging, InfluxClient
from pricing import FairCalculator, SpreadCalculator
from strategies.utils import create_adaptor, Executor, InfoWriter


class StandardMaker:
    """Two-sided market maker that quotes open and close orders around a fair price.

    On every market-data update it computes a fair price and a spread from the
    order book, sizes orders from the current position, and hands price/amount
    pairs to four executors (buy-open, buy-close, sell-open, sell-close). When
    the feed reports a disconnect or pricing fails, all open orders are
    cancelled via ``pause``.
    """

    def __init__(self, markets, accounts, config):
        """Create calculators, executors, the trading account/adaptor and the info writer.

        Args:
            markets: market-data hub; used here via ``timer.register`` and later via
                ``add_market`` / ``subscribe_all`` in ``start``.
            accounts: account factory exposing ``create_account(exchange, key)``.
            config: strategy configuration mapping (keys are read here and in
                ``Variables``).
        """
        self.markets = markets
        self.config = config
        self.variables = Variables(config)                               # central store for strategy parameters

        self.fair_cal = FairCalculator(config['fair_amount'])            # fair (micro) price calculator
        self.spread_cal = SpreadCalculator(config['spread_amount'])      # quoting spread calculator
        self.buy_opener = Executor(self, config, 'buy', Offset.Open)     # buy-to-open executor
        self.buy_closer = Executor(self, config, 'buy', Offset.Close)    # buy-to-close executor
        self.sell_opener = Executor(self, config, 'sell', Offset.Open)   # sell-to-open executor
        self.sell_closer = Executor(self, config, 'sell', Offset.Close)  # sell-to-close executor

        key = config['key']
        self.acc = accounts.create_account(config['exchange'], key)
        self.acc.add_symbol(config['symbol'])
        self.api = self.acc.api

        # The adaptor lets one strategy serve both spot and futures markets;
        # config['maker_type'] selects spot ('spot') or futures ('future').
        self.adaptor = create_adaptor(config['maker_type'], self)
        # Initial blocking query. Spot and futures query different things, so
        # each adaptor defines its own query().
        self.adaptor.query(wait=True)

        # Re-run the adaptor query periodically (interval argument 5; units per the timer API).
        markets.timer.register(self.adaptor.query, 5)
        self.info_writer = self.init_info_writer()                       # writes trading info to InfluxDB

    def init_info_writer(self):
        """Connect to InfluxDB and start every InfoWriter stream used by this strategy.

        Returns:
            The configured ``InfoWriter`` instance.
        """
        var = self.variables
        db_conf = self.config['database']

        client = InfluxClient(db_conf['host'], db_conf['port'], db_conf['username'], db_conf['password'])
        info_writer = InfoWriter(self, client)

        info_writer.start_write_pricing('fair', self.fair_cal)
        info_writer.start_write_pricing('spread', self.spread_cal)
        info_writer.start_write_mid(var.exchange, var.symbol)
        info_writer.start_write_orders()
        info_writer.start_write_execution()
        info_writer.start_write_deal()
        info_writer.start_write_position()
        info_writer.start_write_profit()
        info_writer.start_write_asset()
        info_writer.start_write_logging()
        info_writer.start_write_api_error()
        return info_writer

    def start(self):
        """Register the configured markets and route all their updates to ``on_market_data``.

        ``config['market_data']`` is an iterable of ``(event, exchange, symbol)`` triples.
        """
        to_add = self.config['market_data']
        for event, exchange, symbol in to_add:
            self.markets.add_market(event, exchange, symbol)
        self.markets.subscribe_all(self.on_market_data)

    def on_market_data(self, market_data):
        """Re-price and re-size all four quote sides on a market-data update.

        Args:
            market_data: update object whose ``data`` attribute holds the order
                book, or ``None`` when the connection has dropped.
        """
        book = market_data.data
        if book is None:
            # Disconnected: pull all quotes and wait for the automatic reconnect.
            self.pause()
            logging.info('接口已断开，等待自动重连')
            return

        # Was misspelled `calcualte`; both calculators share the `calculate` API.
        fair = self.fair_cal.calculate(book)
        spread = self.spread_cal.calculate(book)
        if fair is None or spread is None:
            self.pause()  # not enough book depth, pricing failed: pause the strategy
            return

        pos = self.adaptor.position()
        buy_open_amt, buy_close_amt, sell_open_amt, sell_close_amt = self.cal_amt(pos)
        buy_open_prc, buy_close_prc, sell_open_prc, sell_close_prc = self.cal_prc(fair, spread)

        self.buy_opener.execute(buy_open_prc, buy_open_amt)
        self.buy_closer.execute(buy_close_prc, buy_close_amt)
        self.sell_opener.execute(sell_open_prc, sell_open_amt)
        self.sell_closer.execute(sell_close_prc, sell_close_amt)

    def pause(self):
        """Request cancellation of every open order on the account."""
        # Iterate a snapshot: a cancel may remove the order from `acc.orders`
        # synchronously, which would break iteration over the live dict view.
        for o in list(self.acc.orders.values()):
            self.try_cancel(o)

    def cal_prc(self, fair, spread):
        """Return (buy_open, buy_close, sell_open, sell_close) prices around ``fair``.

        Open orders sit a full ``spread`` away from ``fair``. Close orders sit
        ``spread * close_spread_multiply`` away, so a multiplier below 1 quotes
        closes more aggressively.
        """
        close_multiply = self.variables.close_spread_multiply
        buy_open = fair - spread
        buy_close = fair - spread*close_multiply
        sell_open = fair + spread
        sell_close = fair + spread*close_multiply
        return buy_open, buy_close, sell_open, sell_close

    def cal_amt(self, pos):
        """Return non-negative (buy_open, buy_close, sell_open, sell_close) order sizes.

        Args:
            pos: signed net position (positive = long, negative = short, as
                implied by the close sizing below).

        Opening sizes are ``variables.amount``, capped so the resulting
        position stays within ``±variables.max_pos``. Closing sizes flatten
        the current position on the side that reduces it. Every size is
        clamped at 0, so the wrong-side close and any open beyond the limit
        become 0 instead of a negative amount.
        """
        amt = self.variables.amount
        max_ = self.variables.max_pos

        buy_open = max(0, min(amt, max_ - pos))    # room left before +max_pos
        sell_open = max(0, min(amt, max_ + pos))   # room left before -max_pos
        buy_close = max(0, -pos)                   # only buys to close a short
        sell_close = max(0, pos)                   # only sells to close a long

        return buy_open, buy_close, sell_open, sell_close

    def try_cancel(self, order):
        """Cancel ``order`` unless it is still being placed or is already being cancelled."""
        if order.status in (OrderStatus.Placing, OrderStatus.Canceling):
            return
        self.acc.api.cancel_order(order)


class Variables:
    """Attribute-style view of the strategy parameters taken from ``config``.

    Keys are read in a fixed order; a missing key raises ``KeyError``.
    """

    def __init__(self, config):
        self.exchange = config['exchange']
        symbol = self.symbol = config['symbol']
        self.amount = config['amount']
        self.max_pos = config['max_pos']        # maximum position size
        self.hold = config['hold']              # spot-only base holding; may be None in futures configs
        base, _sep, _quote = symbol.partition('/')
        self.asset = base                       # base asset, e.g. 'btc' for the 'btc/usdt' pair
        self.close_spread_multiply = config['close_spread_multiply']  # spread multiplier for close orders

        db_section = config['database']
        self.database_name = db_section['database_name']
        self.log_id = db_section['log_id']
